diff --git a/dogesec_commons/stixifier/models.py b/dogesec_commons/stixifier/models.py index 1d35ac4..eb34ab6 100644 --- a/dogesec_commons/stixifier/models.py +++ b/dogesec_commons/stixifier/models.py @@ -28,11 +28,13 @@ class Profile(models.Model): created = models.DateTimeField(auto_now_add=True) name = models.CharField(max_length=250, unique=True) extractions = ArrayField(base_field=models.CharField(max_length=256, validators=[partial(validate_extractor, ["ai", "pattern", "lookup"])]), help_text="extraction id(s)") - whitelists = ArrayField(base_field=models.CharField(max_length=256, validators=[partial(validate_extractor, ["whitelist"])]), help_text="whitelist id(s)", default=list) - aliases = ArrayField(base_field=models.CharField(max_length=256, validators=[partial(validate_extractor, ["alias"])]), help_text="alias id(s)", default=list) + whitelists = ArrayField(base_field=models.CharField(max_length=256, validators=[partial(validate_extractor, ["whitelist"])]), help_text="whitelist id(s)", default=list, blank=True) + aliases = ArrayField(base_field=models.CharField(max_length=256, validators=[partial(validate_extractor, ["alias"])]), help_text="alias id(s)", default=list, blank=True) relationship_mode = models.CharField(choices=RelationshipMode.choices, max_length=20, default=RelationshipMode.STANDARD) extract_text_from_image = models.BooleanField(default=False) defang = models.BooleanField(help_text='If the text should be defanged before processing') + ai_settings_relationships = models.CharField(max_length=256, blank=False, null=True) + ai_settings_extractions = ArrayField(base_field=models.CharField(max_length=256), default=list) class Meta: app_label = settings.APP_LABEL diff --git a/dogesec_commons/stixifier/serializers.py b/dogesec_commons/stixifier/serializers.py index b21dc22..f6b3e56 100644 --- a/dogesec_commons/stixifier/serializers.py +++ b/dogesec_commons/stixifier/serializers.py @@ -1,3 +1,4 @@ +import argparse from rest_framework import serializers from . import conf @@ -7,24 +8,66 @@ import txt2stix.txt2stix from urllib.parse import urljoin from django.conf import settings +from django.contrib.postgres.fields import ArrayField +from rest_framework.validators import ValidationError from drf_spectacular.utils import OpenApiResponse, OpenApiExample from drf_spectacular.utils import OpenApiResponse, OpenApiExample +from django.db import models + class ErrorSerializer(serializers.Serializer): message = serializers.CharField(required=True) code = serializers.IntegerField(required=True) details = serializers.DictField(required=False) +def validate_model(model): + if not model: + return None + try: + extractor = txt2stix.txt2stix.parse_model(model) + except BaseException as e: + raise ValidationError(str(e)) + return model + +def uses_ai(slugs): + extractors = txt2stix.extractions.parse_extraction_config( + txt2stix.txt2stix.INCLUDES_PATH + ) + ai_based_extractors = [] + for slug in slugs: + if extractors[slug].type == 'ai': + ai_based_extractors.append(slug) + + if ai_based_extractors: + raise ValidationError(f'AI based extractors `{ai_based_extractors}` used when `ai_settings_extractions` is not configured') + class ProfileSerializer(serializers.ModelSerializer): id = serializers.UUIDField(read_only=True) + ai_settings_relationships = serializers.CharField( + validators=[validate_model], + help_text='(required if AI relationship enabled): passed in format `provider:model`. Can only pass one model at this time.', + allow_null=True, + required=False, + ) + ai_settings_extractions = serializers.ListField( + child=serializers.CharField(max_length=256, validators=[validate_model]), + help_text='(required if AI extractions enabled) passed in format provider[:model] e.g. openai:gpt4o. Can pass more than one value to get extractions from multiple providers. model part is optional', + ) class Meta: model = Profile fields = "__all__" + def validate(self, attrs): + if attrs['relationship_mode'] == 'ai' and not attrs['ai_settings_relationships']: + raise ValidationError('AI `relationship_mode` requires a valid `ai_settings_relationships`') + if not attrs['ai_settings_extractions']: + uses_ai(attrs['extractions']) + return super().validate(attrs) + DEFAULT_400_ERROR = OpenApiResponse( diff --git a/dogesec_commons/stixifier/stixifier.py b/dogesec_commons/stixifier/stixifier.py index 3ce9ade..b1e1760 100644 --- a/dogesec_commons/stixifier/stixifier.py +++ b/dogesec_commons/stixifier/stixifier.py @@ -1,6 +1,7 @@ import io import json import logging +import os from pathlib import Path import shutil import uuid @@ -12,7 +13,7 @@ from file2txt.converter import get_parser_class from txt2stix import txt2stix from txt2stix.stix import txt2stixBundler -from txt2stix.ai_session import GenericAIExtractor +from txt2stix.ai_extractor import BaseAIExtractor from stix2arango.stix2arango import Stix2Arango from django.conf import settings @@ -107,14 +108,16 @@ def txt2stix(self): aliased_input = txt2stix.aliases.transform_all(aliases.values(), input_text) bundler.whitelisted_values = txt2stix.lookups.merge_whitelists(whitelists.values()) - ai_extractor_session = GenericAIExtractor.openai() - all_extracts = txt2stix.extract_all(bundler, extractors_map, aliased_input, ai_extractor=ai_extractor_session) + + ai_extractors = [txt2stix.parse_model(model_str) for model_str in self.profile.ai_settings_extractions] + txt2stix.validate_token_count(settings.INPUT_TOKEN_LIMIT, aliased_input, ai_extractors) + + all_extracts = txt2stix.extract_all(bundler, extractors_map, aliased_input, ai_extractors=ai_extractors) if self.profile.relationship_mode == models.RelationshipMode.AI and sum(map(lambda x: len(x), all_extracts.values())): - txt2stix.extract_relationships_with_ai(bundler, aliased_input, all_extracts, ai_extractor_session) - - if ai_extractor_session.initialized: - (self.tmpdir/f"conversation_{self.report_id}.md").write_text(ai_extractor_session.get_conversation()) + ai_ref_extractor = txt2stix.parse_model(self.profile.ai_settings_relationships) + txt2stix.validate_token_count(settings.INPUT_TOKEN_LIMIT, aliased_input, [ai_ref_extractor]) + txt2stix.extract_relationships_with_ai(bundler, aliased_input, all_extracts, ai_ref_extractor) return bundler diff --git a/pyproject.toml b/pyproject.toml index ed62048..a65fbc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "dogesec_commons" -version = "0.0.1b0" +version = "0.0.1b1" authors = [ { name="DOGESEC", email="noreply@dogesec.com" }, ]