gabrielaltay's picture
update
ac2020e
raw
history blame
2.2 kB
"""Form components for configuration in LegisQA"""
import streamlit as st
from legisqa_local.config.models import PROVIDER_MODELS, CONGRESS_NUMBERS, SPONSOR_PARTIES
def get_generative_config(key_prefix: str) -> dict:
    """Draw the generative-model settings widgets and gather their values.

    Widgets are rendered in this order: provider, model name, temperature,
    max output tokens, markdown escaping, and legislation URLs. Each widget's
    Streamlit key has the form ``"<key_prefix>|<setting name>"``.

    Args:
        key_prefix: String used to namespace every widget key in this form.

    Returns:
        A dict mapping each setting name to the value its widget returned.
    """

    def widget_key(name: str) -> str:
        return f"{key_prefix}|{name}"

    provider = st.selectbox(
        label="provider",
        options=PROVIDER_MODELS.keys(),
        key=widget_key("provider"),
    )
    # The model choices are looked up from whichever provider was just selected.
    model_name = st.selectbox(
        label="model_name",
        options=PROVIDER_MODELS[provider],
        key=widget_key("model_name"),
    )
    temperature = st.slider(
        "temperature",
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=widget_key("temperature"),
    )
    # No explicit `value`: Streamlit starts the slider at min_value.
    max_output_tokens = st.slider(
        "max_output_tokens",
        min_value=8192,
        max_value=16_384,
        key=widget_key("max_output_tokens"),
    )
    should_escape_markdown = st.checkbox(
        "should_escape_markdown",
        value=False,
        key=widget_key("should_escape_markdown"),
    )
    should_add_legis_urls = st.checkbox(
        "should_add_legis_urls",
        value=True,
        key=widget_key("should_add_legis_urls"),
    )

    return {
        "provider": provider,
        "model_name": model_name,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "should_escape_markdown": should_escape_markdown,
        "should_add_legis_urls": should_add_legis_urls,
    }
def get_retrieval_config(key_prefix: str) -> dict:
    """Draw the retrieval/filter settings widgets and gather their values.

    Widgets are rendered in this order:

    - a slider for the number of chunks to retrieve (1-32, default 8);
    - text inputs for a bill ID and a bioguide ID;
    - a congress-number multiselect, defaulting to the last two entries of
      ``CONGRESS_NUMBERS``;
    - a sponsor-party multiselect, defaulting to every entry of
      ``SPONSOR_PARTIES``.

    Each widget's Streamlit key has the form ``"<key_prefix>|<setting name>"``.

    Args:
        key_prefix: String used to namespace every widget key in this form.

    Returns:
        A dict mapping each setting name to the value its widget returned.
    """
    prefix = f"{key_prefix}|"
    # Dict-display values are evaluated left to right, so the widgets are
    # still created top to bottom in the order listed here.
    return {
        "n_ret_docs": st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=32,
            value=8,
            key=prefix + "n_ret_docs",
        ),
        "filter_legis_id": st.text_input(
            "Bill ID (e.g. 118-s-2293)",
            key=prefix + "filter_legis_id",
        ),
        "filter_bioguide_id": st.text_input(
            "Bioguide ID (e.g. R000595)",
            key=prefix + "filter_bioguide_id",
        ),
        # Presumably the last two entries are the most recent congresses;
        # this depends on CONGRESS_NUMBERS ordering (confirm in config.models).
        "filter_congress_nums": st.multiselect(
            "Congress Numbers",
            CONGRESS_NUMBERS,
            default=CONGRESS_NUMBERS[-2:],
            key=prefix + "filter_congress_nums",
        ),
        "filter_sponsor_parties": st.multiselect(
            "Sponsor Party",
            SPONSOR_PARTIES,
            default=SPONSOR_PARTIES,
            key=prefix + "filter_sponsor_parties",
        ),
    }