Module ragability.ragability_query

Module to run a bunch of ragability queries and get the responses.

Functions

def get_args()
Expand source code
def get_args():
    """
    Parse the command line arguments, optionally merge in a config file,
    and load the prompt file.

    Returns:
        dict: the effective configuration, containing at least the keys
            ``input``, ``output``, ``llms``, ``promptfile`` and ``prompt``
            (the prompts loaded from the prompt file).

    Raises:
        Exception: if an llm spec is not of the form ``provider/model``,
            or if ``input``, ``llms`` or ``promptfile`` is missing after
            merging command line and config file.
    """
    parser = argparse.ArgumentParser(description='Run a bunch of ragability queries and get the responses')
    parser.add_argument('--input', '-i', type=str, help='Input file with the facts and queries (or from config), jsonl, json, yaml', required=False)
    parser.add_argument('--output', '-o', type=str, help='Output file with the responses (default: $DATETIME.out.jsonl), jsonl, json, yaml', required=False)
    parser.add_argument("--config", "-c", type=str, help="Config file with the LLM and other info for an experiment, json, jsonl, yaml", required=False)
    parser.add_argument('--llms', '-l', nargs="*", type=str, default=[], help='LLMs to use for the queries (or use config)', required=False)
    parser.add_argument("--promptfile", "-pf", type=str, help="File with the prompt to use for the queries (or use config), jsonl, json, yaml", required=False)
    parser.add_argument("--dry-run", "-n", action="store_true", help="Dry run, do not actually run the queries", required=False)
    parser.add_argument("--all", "-a", action="store_true", help="Run all queries, even if they have a response", required=False)
    parser.add_argument("--logfile", "-f", type=str, help="Log file", required=False)
    parser.add_argument("--debug", "-d", action="store_true", help="Debug mode", required=False)
    parser.add_argument("--verbose", "-v", action="store_true", help="Be more verbose and inform what is happening", required=False)
    args_tmp = parser.parse_args()
    # validate the llm specs early: each must look like "provider/model"
    for llm in args_tmp.llms:
        if not re.match(r"^[a-zA-Z0-9_\-./]+/.+$", llm):
            raise Exception(f"Error: 'llm' field must be in the format 'provider/model' in line: {llm}")
    # convert the argparse namespace to a plain dictionary
    args: dict = dict(vars(args_tmp))

    # if a config file is specified, read the config file using our config reading function and update the arguments.
    # The config data may contain:
    # - input: used only if not specified in the command line arguments
    # - output: used only if not specified in the command line arguments
    # - llms: added to the ones specified in the command line arguments
    # - prompt: used to add config info to the llms specified in the command line arguments
    if args["config"]:
        config = read_config_file(args["config"], update=False)
        oldllms = config.get("llms", [])
        # Merge the args into the config, giving precedence to the args.
        # BUG FIX: only merge options that were actually given on the command
        # line (non-None); previously an unspecified --input/--output (None)
        # clobbered the value supplied by the config file.
        config.update({k: v for k, v in args.items() if v is not None})
        # Add the llms from the args to the llms from the config, but only if
        # the llm is not already mentioned in the config.
        # BUG FIX: the duplicate check must be against the *config's* llms
        # (oldllms); previously it was computed from config["llms"] after the
        # update above, i.e. against the CLI llms themselves, so CLI llms were
        # never added when a config file was used.
        mentionedllm = [llm if isinstance(llm, str) else llm["llm"] for llm in oldllms]
        for llm in args["llms"]:
            if llm not in mentionedllm:
                oldllms.append(llm)
        config["llms"] = oldllms
        args = config
    # make sure we got the input file, prompt and llm arguments, if not, show an error message, then the
    # argparse help message and exit
    if not args["input"]:
        parser.print_help()
        raise Exception("Error: Missing input file")
    if not args["llms"]:
        parser.print_help()
        raise Exception("Error: Missing llms argument")
    if not args["promptfile"]:
        parser.print_help()
        raise Exception("Error: Must specify promptfile")
    # read the prompt file into a list of prompt templates
    args["prompt"] = read_prompt_file(args["promptfile"])
    # update the llm configuration in the args dict
    update_llm_config(args)
    return args

Get the command line arguments

def main()
Expand source code
def main():
    """Entry point: parse the arguments, configure logging, then run the queries."""
    config = get_args()
    # optionally mirror all log output into a file
    if config["logfile"]:
        add_logging_file(config["logfile"])
    # in debug mode, raise the log level and dump the effective configuration
    if config["debug"]:
        set_logging_level(DEBUG)
        pretty = pp_config(config)
        logger.debug(f"Effective arguments: {pretty}")
    run(config)
def run(config: dict)
Expand source code
def run(config: dict):
    """
    Run every configured query, for every applicable prompt, against every
    configured LLM and write the responses to the output file.

    Args:
        config: effective configuration as returned by get_args(); the keys
            input, output, prompt, all, dry_run, verbose and debug are used.

    Raises:
        Exception: if the output file name does not end in .json, .jsonl or .hjson.
    """
    # read the input file into memory; we do not expect it to be too large and
    # we want to check the format of all json lines up front
    inputs = read_input_file(config["input"])
    llms = LLMS(config)
    llmnames = llms.list_aliases()
    logger.info(f"LLMs to use: {llmnames}")
    logger.info(f"Loaded {len(inputs)} queries from {config['input']}")
    logger.info(f"Got {len(config['prompt'])} prompts")

    # if the output file is not specified, derive a default name from the
    # current date and time; if the output file exists, it is overwritten
    if not config['output']:
        config['output'] = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.out.jsonl"
    # write either a jsonl or a json/hjson array, depending on the file extension
    if not config['output'].endswith((".json", ".jsonl", ".hjson")):
        raise Exception(f"Error: Output file must end in .json, .jsonl or .hjson, not {config['output']}")
    # .json/.hjson outputs are a single array; .jsonl is one object per line
    is_array = config['output'].endswith(".json") or config['output'].endswith(".hjson")
    # map prompt id -> index into the prompt list for quick lookup
    prompt_idx = {p["pid"]: idx for idx, p in enumerate(config["prompt"])}
    n_llm_errors = 0
    n_outputs = 0
    with open(config['output'], 'w') as f:
        if is_array:
            f.write("[\n")
        for llmname in llmnames:
            for query in inputs:
                # if we have the field "pid" in the query, we use just that prompt, otherwise,
                # if we have the field "pids" in the query, we iterate over those prompts, otherwise we use all prompts
                # defined in the config
                # NOTE: the field "pid" is put into the output of a processed query, which makes it possible to
                # reprocess the output file with the same prompt
                if "pid" in query:
                    pids = [query["pid"]]
                elif "pids" in query:
                    pids = query["pids"]
                else:
                    pids = prompt_idx.keys()
                for pid in pids:
                    prompt_tmpl = config["prompt"][prompt_idx[pid]]
                    logger.debug(f"Processing prompt {pid}")
                    # if the response is already in the query and there is no or an empty error,
                    # skip unless the option --all is given.
                    # BUG FIX: a second, unconditional response check used to follow here and
                    # skipped answered queries even with --all (and even when the previous
                    # attempt had an error), making --all a no-op for answered queries.
                    if not config['all'] and query.get("response") is not None and not query.get("error"):
                        logger.debug(f"Skipping query {query['qid']} with response")
                        continue
                    prompt = prompt_tmpl.copy()
                    logger.debug(f"Processing query {query['qid']}")
                    # normalize the facts: None means no facts, an empty list is
                    # treated like None, a single string becomes a one-element list
                    facts = query.get("facts")
                    if facts == []:
                        facts = None
                    elif isinstance(facts, str):
                        facts = [facts]
                    logger.debug(f"Got facts list {facts}")
                    # replace the ${facts} and ${query} variables in each role message
                    for role, content in prompt.items():
                        if role in ROLES:
                            if facts is not None:
                                # optional per-fact template with ${fact} and ${n} (1-based index)
                                facttmpl = prompt.get("fact")
                                if facttmpl:
                                    factsfmt = []
                                    for idx, fact in enumerate(facts):
                                        factsfmt.append(facttmpl.replace("${fact}", fact).replace("${n}", str(idx+1)))
                                    factsfmt = "".join(factsfmt)
                                else:
                                    factsfmt = "\n".join(facts)
                                prompt[role] = content.replace("${facts}", factsfmt)
                            prompt[role] = prompt[role].replace("${query}", query["query"])
                    messages = llms.make_messages(prompt=prompt)
                    if config['dry_run']:
                        logger.info(f"Would query LLM {llmname} with prompt {prompt['pid']} for query {query['qid']}")
                        logger.debug(f"Messages: {messages}")
                        response = ""
                        error = "NOT RUN: DRY-RUN"
                    else:
                        if config['verbose']:
                            logger.info(f"Querying LLM {llmname} with prompt {prompt['pid']} for query {query['qid']}")
                        else:
                            logger.debug(f"Querying LLM {llmname} with prompt {prompt['pid']} for query {query['qid']}")
                        logger.debug(f"Messages: {messages}")
                        ret = llms.query(llmname, messages=messages, return_cost=True, debug=config['debug'])
                        response = ret.get("answer", "")
                        error = ret.get("error", "")
                        if error:
                            logger.warning(f"Error querying LLM {llmname} with prompt {prompt['pid']} for query {query['qid']}: {error}")
                            n_llm_errors += 1

                    # the output entry is the query plus response/error/pid/llm fields
                    towrite = query.copy()
                    towrite["response"] = response
                    towrite["error"] = error
                    towrite["pid"] = pid
                    towrite["llm"] = llmname
                    if config['output'].endswith(".json"):
                        entry = json.dumps(towrite, indent=2)
                    elif config['output'].endswith(".hjson"):
                        entry = hjson.dumps(towrite, indent=2)
                    else:
                        entry = json.dumps(towrite)
                    if is_array:
                        # BUG FIX: entries of a JSON array must be comma-separated;
                        # previously they were written back to back, producing an
                        # invalid .json file.
                        if n_outputs > 0:
                            f.write(",\n")
                        f.write(entry)
                    else:
                        f.write(entry + "\n")
                    n_outputs += 1
        if is_array:
            # close the array; the last entry has no trailing newline of its own
            f.write("\n]\n" if n_outputs > 0 else "]\n")

    logger.info(f"Wrote {n_outputs} entries to {config['output']}, {n_llm_errors} LLM errors")