o
    S"g                     @  s>  d dl mZ d dlZd dlZd dlZd dlZd dlmZmZ d dl	m	Z	 d dl
mZmZmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d dlZd dlZd d
lmZmZ d dlmZ d dlmZ  d dl!m"Z"m#Z# d dl$m%Z% d dlm&Z& d dl'm(Z( d dlm)Z) d dl*m+Z+ d dl,m-Z- d dl.m/Z/m0Z0 d dl1m2Z2 d dl3m4Z5 d dl6m7Z7m8Z8 d dl9m:Z: d dl;m<Z<m=Z=m>Z> e> rd dl?m@Z@mAZAmBZBmCZCmDZD eEeFZGerd dlHmIZI d dlJmKZK d dlLmMZM G dd de)ZNe2d G d!d" d"eNZOg d#ZPg d$ZQd5d'd(ZRd6d,d-ZSeG d.d/ d/eZTd7d3d4ZUdS )8    )annotationsN)Counterdefaultdict)copy)	dataclassfieldfields)Path)python_versionindent)TYPE_CHECKINGAnyLiteral)CardData	ModelCard)dataset_info)
model_info)
EvalResulteval_results_to_model_index)	yaml_dump)nn)tqdm)TrainerCallback)CodeCarbonCallback)make_markdown_table)TrainerControlTrainerState)
deprecated__version__)StaticEmbeddingTransformer)$SentenceTransformerTrainingArguments)fullnameis_accelerate_availableis_datasets_available)DatasetDatasetDictIterableDatasetIterableDatasetDictValue)SentenceEvaluator)SentenceTransformer)SentenceTransformerTrainerc                      sF   e Zd Zd fddZdddZdddZdddZd ddZ  ZS )!$SentenceTransformerModelCardCallbackdefault_args_dictdict[str, Any]returnNonec                   s   t    || _d S N)super__init__r0   )selfr0   	__class__ b/mnt/skqttb/ctump_chatbot/chatbot/lib/python3.10/site-packages/sentence_transformers/model_card.pyr6   /   s   

z-SentenceTransformerModelCardCallback.__init__argsr#   stater   controlr   modelr-   trainerr.   c                 K  sF  ddl m}m}m}	 |jd dd |jjD }
|
r!|
d |j_|j	r3|j
|j	|jj|jd|j_|jrE|j
|j|jj|jd|j_t|jtrSt|j }n|jg}d}|t|k r|| }t||	||fr{t|dr{|j|vr{||j |d	7 }|t|k s_|j| |jjs|jp|j	 }r|j| d S d S d S )
Nr   )AdaptiveLayerLossMatryoshka2dLossMatryoshkaLossgenerated_from_trainerc                 S  s   g | ]	}t |tr|qS r:   )
isinstancer   ).0callbackr:   r:   r;   
<listcomp>A   s
    
zDSentenceTransformerModelCardCallback.on_init_end.<locals>.<listcomp>trainevalloss   )sentence_transformers.lossesrA   rB   rC   model_card_dataadd_tagscallback_handler	callbackscode_carbon_callbacktrain_datasetextract_dataset_metadatatrain_datasetsrK   eval_dataseteval_datasetsrE   dictlistvalueslenhasattrappend
set_losseswidgetset_widget_examples)r7   r<   r=   r>   r?   r@   kwargsrA   rB   rC   rQ   lossesloss_idxrK   datasetr:   r:   r;   on_init_end3   sB   	

z0SentenceTransformerModelCardCallback.on_init_endc                   sJ   h d |  } fdd| D |j_ fdd| D |j_d S )N>   do_evaldo_testdo_trainrun_name	hub_token	report_to
eval_delay
eval_steps
output_dir
save_stepslogging_dirlogging_stepssave_strategylogging_strategysave_total_limitgreater_is_betterpush_to_hub_tokensamples_per_labelshow_progress_barlogging_first_stepevaluation_strategymetric_for_best_modelc                   s   i | ]\}}| vr||qS r:   r:   rF   keyvalue)ignore_keysr:   r;   
<dictcomp>   s    zGSentenceTransformerModelCardCallback.on_train_begin.<locals>.<dictcomp>c                   s6   i | ]\}}| vr|j v r|j | kr||qS r:   )r0   r|   r   r7   r:   r;   r      s
     )to_dictitemsrN   all_hyperparametersnon_default_hyperparameters)r7   r<   r=   r>   r?   ra   	args_dictr:   r   r;   on_train_begini   s   

z3SentenceTransformerModelCardCallback.on_train_beginmetricsdict[str, float]c                   s    fdd D }t |dkrd|v rd|d i}|jjr3|jjd d |jkr3|jjd | d S |jj|j|jd| d S )	Nc                   s4   i | ]}| d rd|ddd  | qS )_loss _rL   N)endswithjoinsplitrF   r}   r   r:   r;   r      s   4 zDSentenceTransformerModelCardCallback.on_evaluate.<locals>.<dictcomp>rL   rK   Validation LossStepEpochr   )r[   rN   training_logsglobal_stepupdater]   epoch)r7   r<   r=   r>   r?   r   ra   	loss_dictr:   r   r;   on_evaluate   s   	z0SentenceTransformerModelCardCallback.on_evaluatelogsc                 K  sv   dht |@ }|r9|jjr&|jjd d |jkr&||  |jjd d< d S |jj|j|j||  d d S d S )NrK   r   r   Training Loss)r   r   r   )setrN   r   r   popr]   r   )r7   r<   r=   r>   r?   r   ra   keysr:   r:   r;   on_log   s   	
z+SentenceTransformerModelCardCallback.on_log)r0   r1   r2   r3   )r<   r#   r=   r   r>   r   r?   r-   r@   r.   r2   r3   )
r<   r#   r=   r   r>   r   r?   r-   r2   r3   )r<   r#   r=   r   r>   r   r?   r-   r   r   r2   r3   )r<   r#   r=   r   r>   r   r?   r-   r   r   r2   r3   )	__name__
__module____qualname__r6   re   r   r   r   __classcell__r:   r:   r8   r;   r/   .   s    

6
*r/   zThe `ModelCardCallback` has been renamed to `SentenceTransformerModelCardCallback` and the former is now deprecated. Please use `SentenceTransformerModelCardCallback` instead.c                      s   e Zd Z fddZ  ZS )ModelCardCallbackc                   s   t  j|i | d S r4   )r5   r6   )r7   r<   ra   r8   r:   r;   r6      s   zModelCardCallback.__init__)r   r   r   r6   r   r:   r:   r8   r;   r      s    r   )languagelicenselibrary_nametagsdatasetsr   pipeline_tagr_   model-indexco2_eq_emissions
base_model)r?   r@   eval_results_dictr2   r1   c                  C  s`   t  ttjtjd} t rddlm} || d< t r$ddlm} || d< ddl	m} || d< | S )N)pythonsentence_transformerstransformerstorchr   r   
accelerater   
tokenizers)
r
   sentence_transformers_versionr   r    r   r%   r   r&   r   r   )versionsaccelerate_versiondatasets_versiontokenizers_versionr:   r:   r;   get_versions   s   r   r~   float | int | strr   c                 C  s   t | tr
t| dS | S )N   )rE   floatroundr~   r:   r:   r;   
format_log   s   

r   c                   @  s4  e Zd ZU dZeedZded< dZded< dZ	ded< dZ
ded	< eedZd
ed< eedZd
ed< dZded< edd dZded< dZded< edddZded< edddZded< eeddZded< eeddZded< eeddZded < eeddZd!ed"< eeddZd
ed#< edddZded$< eeddZd
ed%< edddZd&ed'< eeddZd(ed)< edddZd*ed+< eeddd,Zd-ed.< ed/ddZd0ed1< ed2ddZd3ed4< ed5ddZ ded6< ed7ddZ!ded8< ee"ddZ#d(ed9< ee$e%j&d: ddZ'd;ed<< edddd=Z(d>ed?< ddBdCZ)	/dddGdHZ*ddKdLZ+ddNdOZ,ddRdSZ-	TdddYdZZ.dd\d]Z/ddd_d`Z0ddcddZ1ddidjZ2ddodpZ3ddrdsZ4ddtduZ5dddwdxZ6ddydzZ7dd{d|Z8dd}d~Z9dddZ:dddZ;dd Z<dddZ=dddZ>dddZ?ddddZ@dS ) SentenceTransformerModelCardDataa  A dataclass storing data used in the model card.

    Args:
        language (`Optional[Union[str, List[str]]]`): The model language, either a string or a list,
            e.g. "en" or ["en", "de", "nl"]
        license (`Optional[str]`): The license of the model, e.g. "apache-2.0", "mit",
            or "cc-by-nc-sa-4.0"
        model_name (`Optional[str]`): The pretty name of the model, e.g. "SentenceTransformer based on microsoft/mpnet-base".
        model_id (`Optional[str]`): The model ID when pushing the model to the Hub,
            e.g. "tomaarsen/sbert-mpnet-base-allnli".
        train_datasets (`List[Dict[str, str]]`): A list of the names and/or Hugging Face dataset IDs of the training datasets.
            e.g. [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}, {"name": "STSB"}]
        eval_datasets (`List[Dict[str, str]]`): A list of the names and/or Hugging Face dataset IDs of the evaluation datasets.
            e.g. [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"id": "mteb/stsbenchmark-sts"}]
        task_name (`str`): The human-readable task the model is trained on,
            e.g. "semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more".
        tags (`Optional[List[str]]`): A list of tags for the model,
            e.g. ["sentence-transformers", "sentence-similarity", "feature-extraction"].

    .. tip::

        Install `codecarbon <https://github.com/mlco2/codecarbon>`_ to automatically track carbon emission usage and
        include it in your model cards.

    Example::

        >>> model = SentenceTransformer(
        ...     "microsoft/mpnet-base",
        ...     model_card_data=SentenceTransformerModelCardData(
        ...         model_id="tomaarsen/sbert-mpnet-base-allnli",
        ...         train_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
        ...         eval_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
        ...         license="apache-2.0",
        ...         language="en",
        ...     ),
        ... )
    )default_factoryzstr | list[str] | Noner   N
str | Noner   
model_namemodel_idlist[dict[str, str]]rU   rW   zjsemantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and morestr	task_namec                   C  s   g dS )N)sentence-transformerssentence-similarityzfeature-extractionr:   r:   r:   r:   r;   <lambda>.      z)SentenceTransformerModelCardData.<lambda>zlist[str] | Noner   r   zLiteral['deprecated']generate_widget_examplesF)defaultinitr   base_model_revision)r   r   r1   r   r   z.dict[SentenceEvaluator, dict[str, Any]] | Noner   zlist[dict[str, float]]r   r_   predict_examplelabel_example_listzCodeCarbonCallback | NonerR   dict[str, str]	citationsz
int | Nonebest_model_step)r   r   repr	list[str]r   Tbool
first_saver   intwidget_stepr   r   r   r   versionzmodel_card_template.mdr	   template_path)r   r   r   zSentenceTransformer | Noner?   r2   r3   c                 C  s~   | j  }t| j tr| j g| _ | j| j|d| _| j| j|d| _| jr;| jddkr=t	d| jd d | _d S d S d S )N)infer_languages/rL   zThe provided z} model ID should include the organization or user, such as "tomaarsen/mpnet-base-nli-matryoshka". Setting `model_id` to None.)
r   rE   r   validate_datasetsrU   rW   r   countloggerwarning)r7   r   r:   r:   r;   __post_init__R  s   

z.SentenceTransformerModelCardData.__post_init__dataset_listlist[dict[str, Any]]r   c              	   C  s   g }|D ]r}d|vrd|v r|d |d< d|v rqzt |d }W n ty7   td|d d |d= Y n:w |jrd|rdd|jv rd|jd}|d urdt|trT|g}|D ]}|| jvrc| j	| qV|j
| jvrq| j	|j
 |	| q|S )NnameidzThe dataset `id` z5 does not exist on the Hub. Setting the `id` to None.r   )get_dataset_info	Exceptionr   r   cardDatagetrE   r   r   r]   r   r   )r7   r   r   output_dataset_listrd   infodataset_languager   r:   r:   r;   r   b  s6   


z2SentenceTransformerModelCardData.validate_datasetsrb   list[nn.Module]c              	     s   ddi}|D ]}z	|j ||jj< W q ty   Y qw tt}| D ]\}}|| | q#ddd  fd	d
| D | _| 	dd dd
 |D D  d S )NzSentence Transformersa  
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
rb   r   r2   r   c                 S  s2   t | dkrd| d d d | d  S | d S )NrL   z, r   z and r   )r[   r   )rb   r:   r:   r;   	join_list  s   z>SentenceTransformerModelCardData.set_losses.<locals>.join_listc                   s   i | ]	\}} ||qS r:   r:   )rF   citationrb   r   r:   r;   r         z?SentenceTransformerModelCardData.set_losses.<locals>.<dictcomp>c                 S  s   g | ]}d | qS )zloss:r:   rF   rK   r:   r:   r;   rH         z?SentenceTransformerModelCardData.set_losses.<locals>.<listcomp>c                 S  s   i | ]}|j j|qS r:   )r9   r   r   r:   r:   r;   r     r   )rb   r   r2   r   )
r   r9   r   r   r   rY   r   r]   r   rO   )r7   rb   r   rK   inverted_citationsr   r:   r   r;   r^     s   
"z+SentenceTransformerModelCardData.set_lossesstepc                 C  
   || _ d S r4   )r   )r7   r   r:   r:   r;   set_best_model_step     
z4SentenceTransformerModelCardData.set_best_model_steprd   Dataset | DatasetDictc              	   C  s  t |ttfr	d S t |trt|d}g | _ttjt	|
 dd}d}t| ddddD ]\}}t || tr;q/d	d
 || j D }|| |}t|}|dkrWq/i }	t|tjt|t||dD ]\}
}tdd | D |	|
< qjtt|	 dd d \}}|d | t	||d  d d d }}|D ]e}
dd
 ||
  D }t|dk r|r| }dd
 ||  D }t|dkr|| n||d  t|dk r|st|dk rq| j|d tj|dd  t|d dd |d d | _qq/d S )N)rd      )k  zComputing widget examplesexampleF)descunitleavec                 S  s2   g | ]\}}t |tst |tr|jd v r|qS )>   stringlarge_string)rE   rX   r+   dtype)rF   columnfeaturer:   r:   r;   rH     s    
zHSentenceTransformerModelCardData.set_widget_examples.<locals>.<listcomp>r   c                 s  s.    | ]\}}|d kr| dst|V  qdS )dataset_name_prompt_lengthN)r   r[   r|   r:   r:   r;   	<genexpr>  s    zGSentenceTransformerModelCardData.set_widget_examples.<locals>.<genexpr>c                 S  s   | d S )NrL   r:   )xr:   r:   r;   r     r   zFSentenceTransformerModelCardData.set_widget_examples.<locals>.<lambda>r}   r   c                 S  &   g | ]\}}|d kr| ds|qS r  r  r   rF   r}   sentencer:   r:   r;   rH     
    r   c                 S  r  r  r  r  r:   r:   r;   rH     r  rL   )source_sentence	sentences   )rE   r)   r*   r'   r(   r_   r   randomchoicesrY   r   r   r   featuresselect_columnsr[   	enumerateselectsamplerangeminsumzipsortedr   extendr]   r   )r7   rd   dataset_namesnum_samples_to_checkr  num_samplescolumnsstr_datasetdataset_sizelengthsidxr   indicesr   target_indicesbackup_indicesr  
backup_idxbackup_sampler:   r:   r;   r`     sb   

(

&z4SentenceTransformerModelCardData.set_widget_examplesr   	evaluatorr,   r   r   c                   s   ddl m} t|| j|< t|dr[|j  r]t||r%dd |jD  nt tr- g  fdd|	 D }| j
rN| j
d d	 |krN| j
d | d S | j
||d
| d S d S d S )Nr   )SequentialEvaluatorprimary_metricc                 S  s   g | ]}|j qS r:   r6  )rF   sub_evaluatorr:   r:   r;   rH     s    zKSentenceTransformerModelCardData.set_evaluation_metrics.<locals>.<listcomp>c                   s   i | ]\}}| v r||qS r:   r:   r|   primary_metricsr:   r;   r          zKSentenceTransformerModelCardData.set_evaluation_metrics.<locals>.<dictcomp>r   r   r   ) sentence_transformers.evaluationr5  r   r   r\   r6  rE   
evaluatorsr   r   r   r   r]   )r7   r4  r   r   r   r5  training_log_metricsr:   r9  r;   set_evaluation_metrics  s$   

z7SentenceTransformerModelCardData.set_evaluation_metricsr'   c                   s   d}t t}t }|D ]1}|d }|d }||vr3|| dt| d t|| |kr3|| t| jkr< nq fdd| D  _	d S )Nr  textlabelz<li>z</li>c                   sF   g | ]\}} j jrt|tr j j| n|d d| d dqS )z<ul> z</ul>)LabelExamples)r?   labelsrE   r   r   )rF   rA  example_setr7   r:   r;   rH     s     zGSentenceTransformerModelCardData.set_label_examples.<locals>.<listcomp>)
r   rY   r   r]   r   r[   addnum_classesr   r   )r7   rd   num_examples_per_labelexamplesfinished_labelsr   r@  rA  r:   rG  r;   set_label_examples  s    

z3SentenceTransformerModelCardData.set_label_examplesr  c                   s   t |tr fdd| D S |rtd|rd }|p|jjt|jd}|jj	r;|j|jj	v r;|jj	|j j
|d< |j }rwt| d }|drwd|v rw|tdd  d}|d |d	< |d
 dd  }rwt|dkrw||d< |gS )Nc                   s(   g | ]\}} j ||d D ]}|qqS ))r  )infer_datasets)rF   r  sub_datasetrd   rG  r:   r;   rH   $  s    zCSentenceTransformerModelCardData.infer_datasets.<locals>.<listcomp>z_dataset_\d+)r   r   sizer   zhf://datasets/@r   rL   r   (   revision)rE   r(   r   rematchr   r  r   r   splitsnum_examplesdownload_checksumsrY   r   
startswithr[   )r7   rd   r  dataset_output	checksumssourcesource_partsrS  r:   rG  r;   rN  "  s&   



"z/SentenceTransformerModelCardData.infer_datasetsr@  str | list[str]c                 C  s   | j |S r4   )r?   tokenize)r7   r@  r:   r:   r;   r_  A  s   z)SentenceTransformerModelCardData.tokenize Dataset | IterableDataset | Noner   rK   'dict[str, nn.Module] | nn.Module | Nonec                   s(  |si S t |trt||d< dd |jD |d< i |d< t |tr|jD ]}|dd | d }t |tr| }t |trTd	|v rT|d	 jd
d }d}n	dd D }d}dt	t
|d d| t	t|t| d d| t	t|d d| dd|d |< q&t |ttfrt d fddt D d|d |< q&t |trdt	t
dt	tt dt	tddd|d |< q&t |trtdd D  t d
krddt| did|d |< q&dt
  dt t  ddt  ddd|d |< q&t|i d|d |< q&d<ddd d!id"d |d  D d d#ifd$d|d  D g}	tt|	d%d&d'|d(< |dd) |d*< t|d* t|d* d  }
g }t|
D ]Y}i }|jD ]K}|d* | | }t |trt|d+krt|dd+ dd, d- }t |trt|dkr|dd d. }t|d/d0}d1| d2||< q{|| qttt|d%d&d'|d3< d4t|i|d5< t|d6r| }z	tj|d7d8}W n ty   t|}Y nw td9| d:d'|d5 d;< |S )=a  
        Given a dataset, compute the following:
        * Dataset Size
        * Dataset Columns
        * Dataset Stats
            - Strings: min, mean, max word count/token length
            - Integers: Counter() instance
            - Floats: min, mean, max range
            - List: number of elements or min, mean, max number of elements
        * 3 Example samples
        * Loss function name
            - Loss function config
        rP  c                 S  s   g | ]}d | dqS )<code></code>r:   )rF   r
  r:   r:   r;   rH   ]      zLSentenceTransformerModelCardData.compute_dataset_metrics.<locals>.<listcomp>r*  statsNr  r   attention_maskrL   )dimtokensc                 S     g | ]}t |qS r:   r[   )rF   r  r:   r:   r;   rH   i      
charactersr     r   )r"  meanmax)r	  datar   c                   s8   i | ]}|t  d krdnd  | t  dqS )rL   ~rB  z.2%rj  r   )counter
subsectionr:   r;   r   w  s    *zLSentenceTransformerModelCardData.compute_dataset_metrics.<locals>.<dictcomp>r   c                 S  ri  r:   rj  )rF   lstr:   r:   r;   rH     rk  rY   z	 elementsz.2frp  rX   c                 S  s    dd dd |  D  d S )Nz<ul><li>z	</li><li>c                 s  s"    | ]\}}| d | V  qdS )z: Nr:   r|   r:   r:   r;   r    s     zaSentenceTransformerModelCardData.compute_dataset_metrics.<locals>.to_html_list.<locals>.<genexpr>z
</li></ul>)r   r   rp  r:   r:   r;   to_html_list  s    zNSentenceTransformerModelCardData.compute_dataset_metrics.<locals>.to_html_listrB  typec                 S  s   i | ]	\}}||d  qS )r	  r:   r|   r:   r:   r;   r     r   detailsc                   s   i | ]\}}| |d  qS ru  r:   r|   )rv  r:   r;   r     r;  -:|--|  stats_tabler  rK  r   r   z, ...]z...
z<br>rb  rc  examples_tabler$   rK   get_config_dictr   r   ```json

```config_code)rp  rX   )rE   r'   r[   column_namesr   r_  rX   r#  tolistr   r"  ro  r   r   r   r%  r   rY   r$   r   r   r   replacer!  r]   r\   r  jsondumps	TypeError)r7   rd   r   rK   r
  first	tokenizedr-  suffixstats_linesr)  examples_lines
sample_idxr*  r~   config
str_configr:   )rr  rs  rv  r;   compute_dataset_metricsD  s   




	
 
z8SentenceTransformerModelCardData.compute_dataset_metricsdataset_metadata nn.Module | dict[str, nn.Module]dataset_typeLiteral['train', 'eval']c              	     s   |rV|r-t |trt|t|kst |tr-t|dkr-td| d| d| d g }|s4|}t |trL fddt| |	 |D }n

||d  g}|d	krmtd
d |D }|rmd|  |S )NrL   zThe number of `z?_datasets` in the model card data does not match the number of z1 datasets in the Trainer. Removing the provided `z$_datasets` from the model card data.c              	     s2   g | ]\}}} ||t tr | n qS r:   )r  rE   rX   )rF   r  dataset_valuer   rK   r7   r:   r;   rH     s    zMSentenceTransformerModelCardData.extract_dataset_metadata.<locals>.<listcomp>r   rI   c                 S  s   g | ]}| d dqS )rP  r   )r   )rF   metadatar:   r:   r;   rH     rd  zdataset_size:)rE   r(   r[   r'   r   r   rN  r$  r   rZ   r  r#  rO   r   )r7   rd   r  rK   r  num_training_samplesr:   r  r;   rT     s8   


z9SentenceTransformerModelCardData.extract_dataset_metadatar-   c                 C  r   r4   )r?   )r7   r?   r:   r:   r;   register_model  r   z/SentenceTransformerModelCardData.register_modelc                 C  r   r4   )r   )r7   r   r:   r:   r;   set_model_id  r   z-SentenceTransformerModelCardData.set_model_idrS  c                 C  sJ   zt |}W n
 ty   Y dS w |j| _|d u s|dkr |j}|| _dS )NFmainT)get_model_infor   r   r   shar   )r7   r   rS  r   r:   r:   r;   set_base_model  s   z/SentenceTransformerModelCardData.set_base_modelc                 C  s   t |tr|g}|| _d S r4   )rE   r   r   )r7   r   r:   r:   r;   set_language  s   

z-SentenceTransformerModelCardData.set_languagec                 C  r   r4   )r   )r7   r   r:   r:   r;   set_license  r   z,SentenceTransformerModelCardData.set_licensec                 C  s4   t |tr|g}|D ]}|| jvr| j| q
d S r4   )rE   r   r   r]   )r7   r   tagr:   r:   r;   rO     s   

z)SentenceTransformerModelCardData.add_tagsc                   s   t | jd trD| jd jjj}t|}d|jdd  g}|j	
d | fddtdt D 7 }|D ]
}| |rA d S q7d S t | jd tr]| jd jr_| | jd j d S d S d S )Nr   r   r   c                   s4   g | ]}d   d| d d   |d  qS )r   Nr   )r   )rF   r.  rV  r:   r;   rH     s    (zJSentenceTransformerModelCardData.try_to_set_base_model.<locals>.<listcomp>rL   )rE   r?   r"   
auto_modelr  _name_or_pathr	   r   partsr   r   r!  r[   r  r!   r   )r7   r   base_model_pathcandidate_model_idsr   r:   r  r;   try_to_set_base_model  s$   
z6SentenceTransformerModelCardData.try_to_set_base_modelc              	     s  g }i }g }| j  D ]\}}t|ddt|ddrHtfdd| D rHfdd| D }rHd rHtd	 d d)ddfdd| D }fdd| D }|jt|dd d}t|dr|	  }rz	t
j|dd}	W n ty   t|}	Y nw td|	 dd}|t| ||d fdd| fdd| D  || qg }
|D ]~}dd |d D }t|}|
D ]g}tdd |d D }|d  |d  kr?||kr?|d! |d! kr?|d" |d" kr?|d D ]}d#|v r|d#||d! < ||d$  ||d! < q	t|d! ts4|d! g|d!< |d! |d!   nq|
| q|
D ]}t|dd%d&|d'< qH|
t| t| j|d(S )*au  Format the evaluation metrics for the model card.

        The following keys will be returned:
        - eval_metrics: A list of dictionaries containing the class name, description, dataset name, and a markdown table
          This is used to display the evaluation metrics in the model card.
        - metrics: A list of all metric keys. This is used in the model card metadata.
        - model-index: A list of dictionaries containing the task name, task type, dataset type, dataset name, metric name,
          metric type, and metric value. This is used to display the evaluation metrics in the model card metadata.
        r   Nr6  c                 3  s    | ]
}|  d  V  qdS )r   N)rY  r   r   r:   r;   r  4  s    zGSentenceTransformerModelCardData.format_eval_metrics.<locals>.<genexpr>c                   s&   i | ]\}}|t  d  d |qS )rL   Nrj  r|   r  r:   r;   r   5  s   & zHSentenceTransformerModelCardData.format_eval_metrics.<locals>.<dictcomp>r   rL   r~   r   r2   c                 S  s0   zt | dr|  W S W | S  ty   Y | S w )z^Try to convert a value from a Numpy or Torch scalar to pure Python, if not already pure Pythonr	  )r\   itemr   r   r:   r:   r;   try_to_pure_python9  s   

zPSentenceTransformerModelCardData.format_eval_metrics.<locals>.try_to_pure_pythonc                   s   i | ]	\}}| |qS r:   r:   r|   )r  r:   r;   r   B  r   c                   sJ   g | ]!\}}| krd | d n|| krd t | d nt |dqS )**Metricr+   )r   rF   
metric_keymetric_valuer7  r:   r;   rH   D  s    zHSentenceTransformerModelCardData.format_eval_metrics.<locals>.<listcomp>rB  r  r   r   r  r  r{  )
class_namedescriptionr  table_linesr  c                   sD   zt | W S  ty   Y nw t| tr d| v r  |  d S d S )Nr   r   )r   r   rE   r   r   )r  )try_to_floatr:   r;   r  c  s   
zJSentenceTransformerModelCardData.format_eval_metrics.<locals>.try_to_floatc                   sj   g | ]1\}}| d urt  dd pd r& ddddnd|dd |dqS )Nr   -unknownr   Unknown)r   	task_typer  r  metric_namemetric_typer  )r   lowerr  titler  )r  r  metric_value_floatr  r:   r;   rH   o  s    

c                 S  s   i | ]	}|d  |d qS r  r:   rF   liner:   r:   r;   r     r   r  c                 s  s    | ]}|d  V  qdS )r  Nr:   r  r:   r:   r;   r    s    r  r  r  r+   r  ry  rz  table)eval_metricsr   r   )r~   r   r2   r   )r   r   getattrallr   rY  r[   r  r\   r  r  r  r  r   r   r]   r$   r&  r   r   r   rE   rY   r   r  r   r   )r7   r  all_metricseval_resultsr4  r   r  r  r  r  grouped_eval_metricseval_metriceval_metric_mappingeval_metric_metricsgrouped_eval_metricgrouped_eval_metric_metricsr  r:   )r  r  r  r   r6  r  r  r;   format_eval_metrics$  s   

	







z4SentenceTransformerModelCardData.format_eval_metricsc                   sv   g  j D ]}| D ]}| vr | qqd fdd}t |dfddj D }t|}|d	|v d
S )Nr}   r   r2   c                   sL   | dkrdS | dkrdS | dkrdS | dkrdS |  d	rd
S  | d S )Nr   r   r   rL   r   rm  r   r  rK   r   r   )r   indexr  )eval_lines_keysr:   r;   sort_metrics  s   
zKSentenceTransformerModelCardData.format_training_logs.<locals>.sort_metricsr  c                   s    g | ]  fd dD qS )c                   sH   i | ] }| d  j krd| v rt | nd dn |dqS )r   r  r  )r   r   r   r   )r  r7   r:   r;   r     s     
zTSentenceTransformerModelCardData.format_training_logs.<locals>.<listcomp>.<dictcomp>r:   )rF   )r7   sorted_eval_lines_keys)r  r;   rH     s    zISentenceTransformerModelCardData.format_training_logs.<locals>.<listcomp>r  )
eval_linesexplain_bold_in_eval)r}   r   r2   r   )r   r   r]   r%  r   )r7   linesr}   r  r   r  r:   )r  r7   r  r;   format_training_logs  s    

	z5SentenceTransformerModelCardData.format_training_logs1dict[Literal['co2_eq_emissions'], dict[str, Any]]c                 C  sd   | j j }dt|jd t|jdd|jdk|j|jt	|j
d ddi}|jr0|j|d d	< |S )
Nr   r  
codecarbonzfine-tuningYi  r  )	emissionsenergy_consumedr\  training_typeon_cloud	cpu_modelram_total_size
hours_usedhardware_used)rR   tracker_prepare_emissions_datar   r  r  r  r  r  r   duration	gpu_model)r7   emissions_dataresultsr:   r:   r;   get_codecarbon_data  s   z4SentenceTransformerModelCardData.get_codecarbon_datac                 C  sV   d}| j jrddddd| j j| j jdd }| j  | j  t| j |dS )	NzCosine SimilarityzDot ProductzEuclidean DistancezManhattan Distance)cosinedot	euclidean	manhattanr   r   )model_max_lengthoutput_dimensionalitymodel_stringsimilarity_fn_name)r?   r  r   r  r  get_max_seq_length get_sentence_embedding_dimensionr   )r7   r  r:   r:   r;   get_model_specific_metadata  s   z<SentenceTransformerModelCardData.get_model_specific_metadatac              
     sr   j r jsz   W n	 ty   Y nw  js/ jr) jjj d j  _n jjj _ fddt D } j	r]z	|
   W n ty\ } z
td|  |d }~ww  jrz	|
   W n ty } ztd|  W Y d }~nd }~ww t jdk|d<  jr jjr jjjd ur|
   |
   d _ tD ]}||d  q|S )	Nz
 based on c                   s   i | ]
}|j t |j qS r:   )r   r  )rF   r   rG  r:   r;   r     s    z<SentenceTransformerModelCardData.to_dict.<locals>.<dictcomp>z+Error while formatting evaluation metrics: z&Error while formatting training logs: d   hide_eval_linesF)r   r   r  r   r   r?   r9   r   r   r   r   r  r   r   r   r  r[   rR   r  _start_timer  r  IGNORED_FIELDSr   )r7   
super_dictexcr}   r:   rG  r;   r     sL   z(SentenceTransformerModelCardData.to_dictc                 C  s$   t dd |   D d|d S )Nc                 S  s*   i | ]\}}|t v r|d g fvr||qS r4   )YAML_FIELDSr|   r:   r:   r;   r   "  s   * z<SentenceTransformerModelCardData.to_yaml.<locals>.<dictcomp>F)	sort_keys
line_break)r   r   r   strip)r7   r  r:   r:   r;   to_yaml   s   z(SentenceTransformerModelCardData.to_yaml)r2   r3   )T)r   r   r   r   r2   r   )rb   r   r2   r3   )r   r   r2   r3   )rd   r   r2   r3   )r   r   )
r4  r,   r   r1   r   r   r   r   r2   r3   )rd   r'   r2   r3   r4   )rd   r   r  r   r2   r   )r@  r^  r2   r1   )rd   r`  r   r1   rK   ra  r2   r   )
rd   r   r  r   rK   r  r  r  r2   r   )r?   r-   r2   r3   )r   r   r2   r3   )r   r   rS  r   r2   r3   )r   r^  r2   r3   )r   r   r2   r3   )r   r^  r2   r3   r2   r1   )r2   r  )r2   r   )Ar   r   r   __doc__r   rY   r   __annotations__r   r   r   rU   rW   r   r   r   r   r   rX   r   r   r   r   r_   r   r   rR   r   r   r   r   r   r   r   r   r   r	   __file__parentr   r?   r   r   r^   r   r`   r?  rM  rN  r_  r  rT   r  r  r  r  r  rO   r  r  r  r  r  r   r  r:   r:   r:   r;   r      sz   
 '


#

L



~
+





 
&

2r   r?   r-   r   c                 C  s   t j| j| jjdd}|jS )Nu   🤗)	card_datar   hf_emoji)r   from_templaterN   r   content)r?   
model_cardr:   r:   r;   generate_model_card(  s   r
  r   )r~   r   r2   r   )r?   r-   r2   r   )V
__future__r   r  loggingr  rT  collectionsr   r   r   dataclassesr   r   r   pathlibr	   platformr
   textwrapr   typingr   r   r   r   r   huggingface_hubr   r   r   r   r   r  huggingface_hub.repocard_datar   r   huggingface_hub.utilsr   r   tqdm.autonotebookr   r   transformers.integrationsr   transformers.modelcardr   transformers.trainer_callbackr   r   typing_extensionsr   r   r    r   sentence_transformers.modelsr!   r"   #sentence_transformers.training_argsr#   sentence_transformers.utilr$   r%   r&   r   r'   r(   r)   r*   r+   	getLoggerr   r   2sentence_transformers.evaluation.SentenceEvaluatorr,   )sentence_transformers.SentenceTransformerr-   sentence_transformers.trainerr.   r/   r   r  r  r   r   r   r
  r:   r:   r:   r;   <module>   sp    
 

      2