Module ragability.ragability_cc_wc1
Module for the CLI to convert the wiki-contradiction corpus to ragability input format
Functions
def get_args()
-
Expand source code
def get_args():
    """
    Get the command line arguments.

    Parses the process command line and returns the parsed options as a
    plain dict with the keys "input", "output", "maxn", "promptfile" and
    "debug". Options that are not given are None ("debug" is False).
    """
    parser = argparse.ArgumentParser(
        description='Convert the WikiContradict-based dataset to ragability input format',
    )
    parser.add_argument('--input', '-i', type=str, help='Input TSV file', required=False)
    parser.add_argument(
        '--output', '-o', type=str,
        help='Output file, name must end in .json, .jsonl or .hjson',
        required=False,
    )
    parser.add_argument(
        '--maxn', '-n', type=int,
        help='Maximum number of input rows to process',
        required=False,
    )
    parser.add_argument(
        '--promptfile', '-p', type=str,
        help='Promptfile to write with the default prompts (do not write)',
        required=False,
    )
    parser.add_argument("--debug", "-d", action="store_true", help="Debug mode", required=False)
    # Return a fresh dict so callers can modify it without touching the Namespace.
    return dict(vars(parser.parse_args()))
Get the command line arguments
def main()
-
Expand source code
def main():
    """
    CLI entry point: parse the arguments, switch to debug logging when
    requested, log the effective arguments and run the conversion.
    """
    cli_args = get_args()
    if cli_args["debug"]:
        set_logging_level(DEBUG)
    # The pretty-printed form is only used for the debug log line.
    logger.debug(f"Effective arguments: {pp_config(cli_args)}")
    run(cli_args)
def row2raga_ctx1(row)
-
Expand source code
def row2raga_ctx1(row):
    """
    Build the entry that asks the row's query with context_1 as the only fact.

    The single check uses row["answer_context1"] as its `check_for` value.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx1" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_1context, kind_1context_q, kind_context1, kind_context1_q, answerable",
        "facts": row["context_1"],
        "query": row["query_text"],
        "pids": ["q_one_context"],
        "checks": [
            {
                "cid": "answer_correct",
                "query": "",
                "func": "affirmative",
                "metrics": ["correct_answer_all", "correct_answer_answerable"],
                "pid": "check_correct_answer",
                "check_for": row["answer_context1"],
            },
        ],
    }
def row2raga_ctx12ic(row)
-
Expand source code
def row2raga_ctx12ic(row):
    """
    Build the contradiction-identification entry with both contexts as facts,
    context_1 first and context_2 second. The query is left empty.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx12ic" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_2contexts, kind_2contexts_ic, kind_context1+2, kind_context1+2_ic, answerable",
        "facts": [row["context_1"], row["context_2"]],
        "query": "",
        "pids": ["ci_two_contexts"],
        "checks": [
            {
                "cid": "answer_correct",
                "func": "negative",
                "metrics": ["correct_answer_all", "contradiction_identification"],
            },
        ],
    }
def row2raga_ctx12q(row)
-
Expand source code
def row2raga_ctx12q(row):
    """
    Build the entry that asks the row's query with both contexts as facts,
    context_1 first and context_2 second. The entry is tagged not_answerable.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx12q" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_2contexts, kind_2contexts_q, kind_context1+2, kind_context1+2_q, not_answerable",
        "facts": [row["context_1"], row["context_2"]],
        "query": row["query_text"],
        "pids": ["q_two_contexts"],
        "checks": [
            {
                "cid": "2ctx_not_answerable",
                "query": "",
                "func": "affirmative",
                "metrics": ["correct_answer_all", "refusal_not_answerable"],
                "pid": "check_response_answerable",
            },
        ],
    }
def row2raga_ctx1ic(row)
-
Expand source code
def row2raga_ctx1ic(row):
    """
    Build the contradiction-identification entry with context_1 as the only
    fact. The query is left empty.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx1ic" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_1context, kind_1context_ic, kind_context1, kind_context1_ic, answerable",
        "facts": row["context_1"],
        "query": "",
        "pids": ["ci_one_context"],
        "checks": [
            {
                "cid": "answer_correct",
                "func": "negative",
                "metrics": ["correct_answer_all", "contradiction_identification"],
            },
        ],
    }
def row2raga_ctx2(row)
-
Expand source code
def row2raga_ctx2(row):
    """
    Build the entry that asks the row's query with context_2 as the only fact.

    The single check uses row["answer_context2"] as its `check_for` value.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx2" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_1context, kind_1context_q, kind_context2, kind_context2_q, answerable",
        "facts": row["context_2"],
        "query": row["query_text"],
        "pids": ["q_one_context"],
        "checks": [
            {
                "cid": "answer_correct",
                "query": "",
                "func": "affirmative",
                "metrics": ["correct_answer_all", "correct_answer_answerable"],
                "pid": "check_correct_answer",
                "check_for": row["answer_context2"],
            },
        ],
    }
def row2raga_ctx21ic(row)
-
Expand source code
def row2raga_ctx21ic(row):
    """
    Build the contradiction-identification entry with both contexts as facts
    in reversed order: context_2 first and context_1 second. The query is
    left empty.
    """
    entry_id = row["contradiction_ID"] + "-" + "ctx21ic" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_2contexts, kind_2contexts_ic, kind_context2+1, kind_context2+1_ic, answerable",
        "facts": [row["context_2"], row["context_1"]],
        "query": "",
        "pids": ["ci_two_contexts"],
        "checks": [
            {
                "cid": "answer_correct",
                "func": "negative",
                "metrics": ["correct_answer_all", "contradiction_identification"],
            },
        ],
    }
def row2raga_ctx21q(row)
-
Expand source code
def row2raga_ctx21q(row):
    """
    Build the entry that asks the row's query with both contexts as facts in
    reversed order: context_2 first and context_1 second. The entry is tagged
    not_answerable.
    """
    out = dict(
        qid=row["contradiction_ID"] + "-" + "ctx21q" + VAR,
        tags="kind_2contexts, kind_2contexts_q, kind_context2+1, kind_context2+1_q, not_answerable",
        # Context 2 goes first, matching the "2+1" qid/tags and the order used
        # by row2raga_ctx21ic; with context 1 first this entry would present
        # exactly the same facts as row2raga_ctx12q.
        facts=[row["context_2"], row["context_1"]],
        query=row["query_text"],
        pids=["q_two_contexts"],
        checks=[
            dict(
                cid="2ctx_not_answerable",
                query="",
                func="affirmative",
                metrics=["correct_answer_all", "refusal_not_answerable"],
                pid="check_response_answerable",
            ),
        ],
    )
    return out
def row2raga_ctx2ic(row)
-
Expand source code
def row2raga_ctx2ic(row):
    """
    Build the contradiction-identification entry with context_2 as the only
    fact. The query is left empty.
    """
    out = dict(
        qid=row["contradiction_ID"] + "-" + "ctx2ic" + VAR,
        tags="kind_1context, kind_1context_ic, kind_context2, kind_context2_ic, answerable",
        facts=row["context_2"],
        query="",
        # Only one context is supplied, so use the single-context prompt id,
        # the same one row2raga_ctx1ic uses (was "ci_two_contexts").
        pids=["ci_one_context"],
        checks=[
            dict(
                cid="answer_correct",
                func="negative",
                metrics=["correct_answer_all", "contradiction_identification"],
            ),
        ],
    )
    return out
def row2raga_nc(row)
-
Expand source code
def row2raga_nc(row):
    """
    Build the entry that asks the row's query without any context; the
    resulting dict has no "facts" key. The entry is tagged not_answerable.
    """
    entry_id = row["contradiction_ID"] + "-" + "nc" + VAR
    return {
        "qid": entry_id,
        "tags": "kind_no_context, kind_no_context_q, not_answerable",
        "query": row["query_text"],
        "pids": ["q_no_context"],
        "checks": [
            {
                "cid": "no_ctx_not_answerable",
                "query": "",
                "func": "affirmative",
                "metrics": ["correct_answer_all", "refusal_not_answerable"],
                "pid": "check_response_answerable",
            },
        ],
    }
def run(config: dict)
-
Expand source code
def run(config: dict):
    """
    Optionally write the default prompts, then convert the input TSV file.

    Every input row is passed through each converter in CONVS; each result is
    written as one entry to the output file. The output format is chosen by
    the file extension: .json (array of indented objects), .hjson (array of
    Hjson objects) or .jsonl (one compact JSON object per line).

    Args:
        config: settings dict; the keys read here are "promptfile", "input",
            "output" and "maxn".

    Raises:
        Exception: if the output file name does not end in .json, .jsonl
            or .hjson.
    """
    pfile = config.get("promptfile")
    if pfile:
        with open(pfile, "wt") as outfp:
            hjson.dump(PROMPTS, outfp)
        logger.info(f"Prompts written to {pfile}")
    if not config.get("input"):
        logger.info("No input file given, exiting")
        return
    if not config.get("output"):
        logger.error("Output file must be given if input file is specified")
        return
    infile = config["input"]
    outfile = config["output"]
    df = pd.read_csv(infile, sep="\t", dtype="string", na_filter=False)
    logger.info(f"Read {len(df)} rows with {df.shape[1]} columns from {infile}")
    # write either a jsonl, json or hjson file, depending on the file extension
    if not outfile.endswith((".json", ".jsonl", ".hjson")):
        raise Exception(f"Error: Output file must end in .json, .jsonl or .hjson, not {outfile}")
    as_json = outfile.endswith(".json")
    as_hjson = outfile.endswith(".hjson")
    maxn = config.get("maxn")
    # itertuples(name=None) yields plain tuples with the index as first element
    colnames = ["index"] + list(df.columns)
    n_rows = 0
    n_inputs = 0
    with open(outfile, "wt") as f:
        if as_json or as_hjson:
            f.write("[\n")
        for row in df.itertuples(name=None):
            n_inputs += 1
            rowdict = dict(zip(colnames, row))
            for conv in CONVS:
                crow = conv(rowdict)
                # copy over the meta data fields
                for field in ["reasoning_required", "notes", "WikiContradict_ID"]:
                    if field in rowdict:
                        crow[field] = rowdict[field]
                n_rows += 1
                if as_json:
                    # Strict JSON requires a comma between array elements;
                    # without it the .json output cannot be parsed.
                    if n_rows > 1:
                        f.write(",\n")
                    f.write(json.dumps(crow, indent=2))
                elif as_hjson:
                    # Hjson accepts newline-separated elements, output unchanged.
                    f.write(hjson.dumps(crow, indent=2) + "\n")
                else:
                    f.write(json.dumps(crow) + "\n")
            if maxn and n_inputs >= maxn:
                logger.info(f"Processed {n_inputs} rows, stopping")
                break
        if as_json:
            # entries are written without a trailing newline, so add it here
            f.write("\n]\n")
        elif as_hjson:
            f.write("]\n")
    logger.info(f"Written {n_rows} entries to {outfile}")