Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx] Align tx match --hash format with the output of tx hash #1289

Closed
wants to merge 5 commits into from
Prev Previous commit
Next Next commit
Fix for Issue 1188. This diff allows for the usage of the hash format…
… produced by the hash command to be used in the match command (i.e. prefixed with 'pdq'). In addition, allows for multiple hashes in a single file of varying signal types. Will throw exceptions when you specify (-S) a signal type, and provide hashes of another type, as well as when there are multiple signal type prefixes with a mixture of no prefixes -- i.e. when we are unable to 'infer' the signal type from the hash
  • Loading branch information
Sam Freeman committed Mar 28, 2023
commit df11f53cd013edf8f6ad4ec22e9ae77dc8b3ec26
120 changes: 108 additions & 12 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,19 @@ def execute(self, settings: CLISettings) -> None:
if self.as_hashes:
types = (BytesHasher, TextHasher, FileHasher)
signal_types = [s for s in signal_types if issubclass(s, types)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(In reference to below comments) This list of signal types have already been narrowed down the valid set given the input.

It might be that we should consider moving --hashes to a toplevel argument (tx match photo|video|hash)` and allowing cross matching might make more sense, but now we've exceeded the scope of the pr.

if self.as_hashes and len(signal_types) > 1:
raise CommandError(
f"Error: '{self.content_type.get_name()}' supports more than one SignalType."
" for '--hashes' also use '--only-signal' to specify one of "
f"{[s.get_name() for s in signal_types]}",
2,
)

logging.info(
"Signal types that apply: %s",
", ".join(s.get_name() for s in signal_types) or "None!",
)

if self.as_hashes:
hashes_grouped_by_prefix = dict()
# Infer the signal types from the prefixes (None is used as key for hashes with no prefix)
for path in self.files:
_group_hashes_by_prefix(path, settings, hashes_grouped_by_prefix)
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
# Validate the SignalType and append the None prefixes to the correct SignalType
self.validate_hashes_signal_type(hashes_grouped_by_prefix, signal_types)
signal_types = list(hashes_grouped_by_prefix.keys())
indices: t.List[t.Tuple[t.Type[SignalType], SignalTypeIndex]] = []
for s_type in signal_types:
index = settings.index.load(s_type)
Expand All @@ -196,11 +196,14 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results = _match_hashes(
hashes_grouped_by_prefix[s_type], s_type, index
)
else:
results = _match_file(path, s_type, index)

for r in results:
# TODO Improve visualisation of a single multiple hash query
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
for collab, fetched_data in metadatas:
if not self.all and collab in seen:
Expand All @@ -215,6 +218,52 @@ def execute(self, settings: CLISettings) -> None:
fetched_data,
)

def validate_hashes_signal_type(
self,
hashes_grouped_by_prefix: t.Dict[t.Optional[SignalType], t.Set[str]],
signal_types: t.List[t.Type[SignalType]],
) -> bool:
if (
len(hashes_grouped_by_prefix) > 2
and None in hashes_grouped_by_prefix.keys()
):
raise CommandError(
f"Error: Provided more than one SignalType and some hashes are missing a prefix",
2,
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
)
if self.only_signal:
if (
self.only_signal not in hashes_grouped_by_prefix.keys()
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
and None not in hashes_grouped_by_prefix.keys()
):
raise CommandError(
f"Error: SignalType '{self.only_signal} was provided, but inferred more from provided hashes."
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
f"Inferred signal types: {', '.join(s_type.get_name() for s_type in hashes_grouped_by_prefix.keys() if s_type)}"
)
if (
len(signal_types) > 1
and len(hashes_grouped_by_prefix) == 1
and None in hashes_grouped_by_prefix.keys()
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
):
raise CommandError(
f"Error: '{self.content_type.get_name()}' supports more than one SignalType"
"No prefix applied to the hashes, cannot infer correct SignalType"
)
# As well as the above validations, also need to combine the None prefixes into the correct SignalType
if None in hashes_grouped_by_prefix.keys():
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
values = set().union(*hashes_grouped_by_prefix.values())
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
keys = list(hashes_grouped_by_prefix.keys())
keys.remove(None)
# Based on the validations, we know that there will only be one key here or one defined in settings
hashes_grouped_by_prefix.clear()
if not self.only_signal:
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
key = signal_types[0]
if len(keys) > 0:
key = keys[0]
else:
key = self.only_signal
hashes_grouped_by_prefix[key] = values


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
Expand All @@ -225,23 +274,70 @@ def _match_file(
return index.query(s_type.hash_from_file(path))


def _group_hashes_by_prefix(
path: pathlib.Path,
settings: CLISettings,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unused - also, I prefer to avoid passing this around because it is somewhat of a god object

hashes_grouped_by_prefix: t.Dict[t.Optional[SignalType], t.Set[str]],
) -> None:
for line in path.read_text().splitlines():
line = line.strip()
if not line:
continue
components = line.split()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn over what to do when len(component) > 2 - this must mean that there are spaces in the hash, which some future type could allow, but it means our naive parsing here will fail oddly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely agree, and had a similar thought -- on re-evaluation I'm thinking for when len(component) > 1 it's probably better to assume that [0] is the prefix and concat [1:] into the hash.

The only issue here is that if there is a future hash with a space, without a prefix the parsing will fail still. I'll have a think on this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your intuition is right - if there are hashes with a space, providing the prefix allows for a non-ambiguous parse. Check out lpartition or split(limit=2)

signal_type = None
if len(components) > 1:
# Assume it has a prefix
possible_type = components[0]
hash = components[1].strip()
try:
signal_type = settings.get_signal_type(possible_type)
hash = signal_type.validate_signal_str(hash)
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
except KeyError:
logging.exception("Signal type '%s' is invalid", possible_type)
raise CommandError(
f"Error attempting to infer Signal Type: '{possible_type}' is not a valid Signal Type.",
2,
)
except Exception as e:
logging.exception(
"%s failed verification on %s", signal_type.get_name(), hash
)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash} from {path} is not a valid hash for {signal_type.get_name()}",
2,
)
else:
# Assume it doesn't have a prefix and is a raw hash
hash = components[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: What about len(component) > 2?

# We can't validate it this point as we have no context on which signal type
hashes = hashes_grouped_by_prefix.get(signal_type, set())
hashes.add(hash)
hashes_grouped_by_prefix[signal_type] = hashes
Comment on lines +335 to +337
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use setdefault again here to simplify (some folks also like defaultdict, but I have burned before!)

hashes_grouped_by_prefix.setdefault(signal_type, set()).add(hash)  # you are done


def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
hashes: t.Set[str],
s_type: t.Type[SignalType],
index: SignalTypeIndex,
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
for hash in path.read_text().splitlines():
for hash in hashes:
hash = hash.strip()
if not hash:
continue
try:
# Need to keep this final validation as we are yet to have validated the hashes without a prefix
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, forgot to remove this additional validation

hash = s_type.validate_signal_str(hash)
except Exception:
logging.exception("%s failed verification on %s", s_type.get_name(), hash)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
f"{hash_repr} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal


class MatchCommandTest(ThreatExchangeCLIE2eTest):
Expand All @@ -29,5 +30,61 @@ def test_invalid_hash(self):
not_hash = "this is not an md5"
self.assert_cli_usage_error(
("-H", "video", "--", not_hash),
f"{not_hash!r} from .* is not a valid hash for video_md5",
f"Error attempting to infer Signal Type: '{not_hash.split()[0]}' is not a valid Signal Type.",
)

def test_valid_hash_with_prefix(self):
hash = "pdq " + PdqSignal.get_examples()[0]
self.assert_cli_output(
("-H", "photo", "--", hash), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)

def test_no_prefix_specific_signal_type(self):
hash = PdqSignal.get_examples()[0]
self.assert_cli_output(
("-H", "-S", "pdq", "photo", "--", hash),
"pdq 16 (Sample Signals) INVESTIGATION_SEED",
)

def test_multiple_prefixes(self):
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = "pdq " + PdqSignal.get_examples()[1]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2)
fp.seek(0)
# CLI is currently showing only one match for multiple hashes
# TODO Improve the handling of multiple hashes in one match query
self.assert_cli_output(
("-H", "photo", fp.name), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)

def test_incorrect_valid_and_no_prefixes(self):
fakeprefix = "fakesignal"
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = fakeprefix + " " + PdqSignal.get_examples()[1]
hash3 = fakeprefix + " " + PdqSignal.get_examples()[2]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2 + "\n")
fp.write(hash3)
fp.seek(0)
self.assert_cli_usage_error(
("-H", "photo", fp.name),
f"Error attempting to infer Signal Type: '{fakeprefix}' is not a valid Signal Type.",
)

def test_prefix_and_no_prefixes(self):
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = "pdq " + PdqSignal.get_examples()[1]
hash3 = PdqSignal.get_examples()[1]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2 + "\n")
fp.write(hash3)
fp.seek(0)
# CLI is currently showing only one match for multiple hashes
# TODO Improve the handling of multiple hashes in one match query
self.assert_cli_output(
("-H", "photo", fp.name), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)